from turtle import color
import numpy as np
from collections import defaultdict
from matplotlib import pyplot as plt
from matplotlib import font_manager as fm, rcParams
from matplotlib import rc
import os
import pandas as pd
import seaborn as sns
import argparse


# Shared plot styling for the first figures; `s` is the tick-label font size.
s = 30
rc_ = {
    'figure.figsize': (8, 8),
    'axes.labelsize': 40,
    'xtick.labelsize': s,
    'ytick.labelsize': s,
    'legend.fontsize': 25,
}
sns.set(rc=rc_, style="darkgrid")

# Fetch the palette once and pick colours by index.
# NOTE(review): in seaborn's "colorblind" palette index 1 is orange and index 2
# is green, so `cgreen`/`cred` do not match their names. The indices are kept
# unchanged so existing figures keep their colours; confirm the intended hues.
_colorblind = sns.color_palette("colorblind")
cblue = _colorblind[0]
cgreen = _colorblind[1]
cred = _colorblind[2]
# rc('text', usetex=True)

parser = argparse.ArgumentParser()
parser.add_argument(
    '--path',
    default='./images',
    help="path"
)
args = parser.parse_args()

# Every figure below is written into args.path via savefig(), which raises
# FileNotFoundError if the directory does not exist; create it up front.
os.makedirs(args.path, exist_ok=True)

# Vary p [0 1], Vary n [0 2] in p^n, Vary rmin [-5 0) 

# #####################################################################################
rmin, rmax = -1, 0


def _q_values(P, R, V, s):
    """Return the (2, batch) array of one-step action values at state `s`.

    Entry [a, i] is sum over s' of P[i, s, s', a] * (R[i, s, s', a] + V[i, s']),
    i.e. an undiscounted Bellman backup evaluated independently for every
    batch row i. V is read at call time, so in-place sweeps see fresh values.
    """
    return np.array([
        np.array([P[:, s, s_, a] * (R[:, s, s_, a] + V[:, s_]) for s_ in range(P.shape[2])]).sum(axis=0)
        for a in range(2)
    ])


for p in [0, 0.4]:
    # D, C = 2, max([(1-p)-p,p-(1-p)]) # 0.72
    # minmax = np.min([rmin, (rmin-rmax)*(D/C)])
    minmax = -2
    if p == 0.4:
        minmax = -2.63  # Obtained from chain_walk_pvsp.py
    penalty = np.linspace(-5, 0, 6)
    print("penalty:", penalty)

    # 4-state chain MDP; the leading axis of every array indexes the candidate
    # penalties, all sharing the same transition probabilities.
    states = 4
    P = np.zeros((penalty.shape[0], states, states, 2))  # penalty, S, S', A
    P[:, 3, 3, :] = 1.0
    P[:, 1, 1, :] = 1.0
    P[:, 2, 2, :] = np.array([p, p])
    P[:, 2, 3, :] = np.array([1-p, 1-p])
    P[:, 0, 1, :] = np.array([1-p, p])
    P[:, 0, 2, :] = np.array([p, 1-p])
    # Reward rmin on every transition except those out of s1 and s3 (reward 0);
    # the s0 -> s1 transition instead receives the candidate penalty.
    R = rmin*np.ones((penalty.shape[0], states, states, 2))  # penalty, S, S', A
    R[:, [1, 3], :, :] = 0.0
    R[:, 0, 1, 0] = penalty
    R[:, 0, 1, 1] = penalty
    V = np.zeros((penalty.shape[0], states))  # penalty, S
    pi = np.zeros((penalty.shape[0], states))  # penalty, S

    # Damped (step size 0.1), in-place value iteration. convergence[i] records
    # the first sweep at which row i changed by at most 1e-10.
    convergence = np.zeros(penalty.shape[0])
    step = 0
    while True:
        step += 1
        V_pre = V.copy()
        for s in range(states):
            Vs = _q_values(P, R, V, s).max(axis=0)
            V[:, s] = V[:, s] + 0.1*(Vs - V[:, s])
        row_change = np.abs(V_pre - V).max(axis=1)
        convergence[(row_change <= 1e-10) & (convergence == 0)] = step
        if row_change.max() <= 1e-10:
            break

    # Greedy policy. Q-values are rounded so near-ties become exact ties, which
    # argmax resolves to the first action (index 0).
    for s in range(states):
        pi_ = _q_values(P, R, V, s).round(10)
        pi[:, s] = pi_.argmax(axis=0)
        print("debug: ", s, pi_, pi_[0]==pi_[1], pi_.argmax(axis=0))

    # Min-max normalise the step counts. When every row converged at the same
    # step the range is 0 and the plain formula would produce NaNs.
    span = convergence.max() - convergence.min()
    if span > 0:
        convergence = (convergence - convergence.min())/span
    else:
        convergence = np.zeros_like(convergence)
    print("convergence:", convergence)
    print("V:", V)

    # Choosing action 1 in s0 succeeds with probability 1-p, action 0 with p.
    success = np.where(pi[:, 0] == 1, 1.0 - p, 1.0*p)
    print(1-success)
    # #####################################################################################
    lw = 5.0
    fig, ax = plt.subplots()
    ax.plot(penalty, convergence, label=r'Timesteps', marker="o", c=cblue, ms=20, lw=lw)
    ax.plot(penalty, 1-success, label=r'Failures', marker="o", c=cred, ms=20, lw=lw)
    plt.axvline(x=minmax, color="black", label=r"Minmax", linestyle="--", lw=lw)

    ax.legend()
    plt.xlabel("Penalty")
    fig.tight_layout()
    fig.savefig("{}/{}.pdf".format(args.path, f"convergence_{p}"), bbox_inches='tight')
    plt.show()
# #####################################################################################
p = np.linspace(0, 1, 1000)
# Per-state quantities shown in the "controllability" figure.
delta_p_s0 = 1 - p
delta_p_s0_ = p
delta_p_s0_c = np.minimum(delta_p_s0, delta_p_s0_)
C_p = delta_p_s0_c
# The plot below and the penalty computation further down both read `C`, which
# was never assigned (only `C_p` was), so the script stopped with a NameError.
C = C_p
# NOTE(review): Delta P_{s_2}(pi_1, pi_2) is plotted below but was never
# defined in this script (another NameError). It is left as None so the figure
# still renders without that curve; supply the intended definition to restore it.
delta_p_s2 = None

#####################################################################################
s = 30
rc_ = {
    'figure.figsize': (12, 8),
    'axes.labelsize': 30,
    'xtick.labelsize': s,
    'ytick.labelsize': s,
    'legend.fontsize': 30,
}
sns.set(rc=rc_, style="darkgrid", palette="deep")

lw = 10.0
fig, ax = plt.subplots()
l1, = ax.plot(p, delta_p_s0, label=r'$\Delta P_{s_0}(\pi_1, \pi_2)$', c=cred, lw=lw)
l2, = ax.plot(p, delta_p_s0_, label=r'$\Delta P_{s_0}(\pi_2, \pi_1)$', c=cgreen, lw=lw)
lower_handles, lower_labels = [], []
if delta_p_s2 is not None:
    l3, = ax.plot(p, delta_p_s2, label=r'$\Delta P_{s_2}(\pi_1, \pi_2)$', c=cblue, lw=lw)
    lower_handles.append(l3)
    lower_labels.append(r'$\Delta P_{s_2}(\pi_1, \pi_2)$')
# The C curve omits the last grid point (p = 1).
l4, = ax.plot(p[:-1], C[:-1], label=r'$C$', color="black", linestyle="dashed", lw=lw)
lower_handles.append(l4)
lower_labels.append(r'$C$')

# The second ax.legend() call replaces the first legend, so add_artist
# re-attaches it and both are drawn.
legend1 = ax.legend([l1, l2], [r'$\Delta P_{s_0}(\pi_1, \pi_2)$', r'$\Delta P_{s_0}(\pi_2, \pi_1)$'], loc='upper center')
ax.legend(lower_handles, lower_labels, loc='lower center')
ax.add_artist(legend1)
plt.xlabel("p")
plt.ylabel(r'$\Delta P_s$')
fig.tight_layout()
fig.savefig("{}/{}.pdf".format(args.path, "controllability"), bbox_inches='tight')
plt.show()
#####################################################################################

# Drop the last grid point (p = 1) for everything below -- presumably so that
# C is nonzero in the D/C and 1/C terms; confirm.
p = p[:-1]
C = C[:-1]

# Chain MDP with reward 1 on every transition out of s0 and s2 and reward 0
# out of the absorbing states s1 and s3, so V counts (maximum expected) steps
# before absorption. D is the largest such value per p.
num_p = p.shape[0]
P = np.zeros((num_p, 4, 4, 2))  # p, S, S', A
P[:, 1, 1, :] = 1.0
P[:, 3, 3, :] = 1.0
P[:, 2, 2, :] = p[:, None]
P[:, 2, 3, :] = (1 - p)[:, None]
P[:, 0, 1, 0] = 1 - p
P[:, 0, 1, 1] = p
P[:, 0, 2, 0] = p
P[:, 0, 2, 1] = 1 - p
R = np.ones((num_p, 4, 4, 2))  # p, S, S', A
R[:, [1, 3], :, :] = 0.0
V = np.zeros((num_p, 4))  # p, S

# In-place value iteration, repeated until a full sweep changes nothing.
while True:
    V_pre = V.copy()
    for s in range(4):
        q = np.array([
            np.array([P[:, s, nxt, act] * (R[:, s, nxt, act] + V[:, nxt]) for nxt in range(4)]).sum(axis=0)
            for act in range(2)
        ])
        V[:, s] = q.max(axis=0)
    if np.array_equal(V, V_pre):
        break
D = V.max(axis=1)
print("p", p)
print("D", D)

# Three candidate penalties, each capped above at rmin: scaled by D/C, by D
# alone, and by 1/C. Where C is 0 the D/C and 1/C terms are infinite (numpy
# emits a divide-by-zero warning).
reward_gap = rmin - rmax
floor = np.full(num_p, rmin, dtype=float)
penalty = np.minimum(floor, reward_gap*(D/C))
penalty1 = np.minimum(floor, reward_gap*D)
penalty2 = np.minimum(floor, reward_gap*(1/C))

#####################################################################################
lw = 5.0
fig, ax = plt.subplots()
for values, label, colour in (
    (penalty, r'$(R_{MIN} - R_{MAX})\frac{D}{C}$', cred),
    (penalty2, r'$(R_{MIN} - R_{MAX})\frac{1}{C}$', cgreen),
    (penalty1, r'$(R_{MIN} - R_{MAX})D$', cblue),
):
    ax.plot(p, values, label=label, c=colour, lw=lw)
# NOTE(review): these markers are labelled C=0 at p = 0.5 and p = 1.0, but
# C_p = min(p, 1 - p) as computed earlier in this script is 0 at p = 0 and
# p = 1 (and 0.5 at p = 0.5); confirm where the markers belong.
ax.axvline(x=0.5, ymax=0.95, color="black", label=r"$C=0$", linestyle="--", lw=lw)
ax.axvline(x=1.0, ymax=0.95, color="black", linestyle="--", lw=lw)

ax.legend()
ax.set_xlabel("p")
ax.set_ylabel(r'Penalty')
ax.set_ylim(-50, 0)
fig.tight_layout()
fig.savefig("{}/{}.pdf".format(args.path, "penalty"), bbox_inches='tight')
plt.show()
#####################################################################################

# Rebuild the chain MDP with reward rmin per step and solve it once per
# candidate penalty placed on the s0 -> s1 transition.
num_p = p.shape[0]
P = np.zeros((num_p, 4, 4, 2))  # p, S, S', A
P[:, 1, 1, :] = 1.0
P[:, 3, 3, :] = 1.0
P[:, 2, 2, :] = p[:, None]
P[:, 2, 3, :] = (1 - p)[:, None]
P[:, 0, 1, 0] = 1 - p
P[:, 0, 1, 1] = p
P[:, 0, 2, 0] = p
P[:, 0, 2, 1] = 1 - p
R = np.full((num_p, 4, 4, 2), rmin, dtype=float)  # p, S, S', A
R[:, [1, 3], :, :] = 0.0

# Only the D/C penalty is nudged (in place) just below its bound;
# penalty1 and penalty2 are used unchanged.
penalty -= 1e-10
R[:, 0, 1, :] = penalty[:, None]
R1 = R.copy()
R1[:, 0, 1, :] = penalty1[:, None]
R2 = R.copy()
R2[:, 0, 1, :] = penalty2[:, None]
V, V1, V2 = (np.zeros((num_p, 4)) for _ in range(3))
pi, pi1, pi2 = (np.zeros((num_p, 4)) for _ in range(3))

# The three systems are independent: sweep each in place and stop once a
# sweep leaves all of them exactly unchanged.
systems = [(R, V), (R1, V1), (R2, V2)]
while True:
    snapshots = [Vk.copy() for _, Vk in systems]
    for Rk, Vk in systems:
        for s in range(4):
            q = np.array([
                np.array([P[:, s, nxt, act] * (Rk[:, s, nxt, act] + Vk[:, nxt]) for nxt in range(4)]).sum(axis=0)
                for act in range(2)
            ])
            Vk[:, s] = q.max(axis=0)
    if all(np.array_equal(Vk, snap) for (_, Vk), snap in zip(systems, snapshots)):
        break

# Greedy policies; rounding turns near-ties into exact ties, broken by argmax
# in favour of action 0.
for s in range(4):
    for Rk, Vk, pik in ((R, V, pi), (R1, V1, pi1), (R2, V2, pi2)):
        q = np.array([
            np.array([P[:, s, nxt, act] * (Rk[:, s, nxt, act] + Vk[:, nxt]) for nxt in range(4)]).sum(axis=0)
            for act in range(2)
        ])
        pik[:, s] = q.round(10).argmax(axis=0)

# Action 1 in s0 succeeds with probability 1-p (fails with p); action 0 the reverse.
(success, failure), (success1, failure1), (success2, failure2) = (
    (np.where(pol[:, 0] == 1, 1 - p, p), np.where(pol[:, 0] == 1, p, 1 - p))
    for pol in (pi, pi1, pi2)
)

#####################################################################################
s = 20
rc_ = {
    'figure.figsize': (12, 8),
    'axes.labelsize': 30,
    'xtick.labelsize': s,
    'ytick.labelsize': s,
    'legend.fontsize': 20,
}
sns.set(rc=rc_, style="darkgrid", palette="deep")

lw = 5.0
fig = plt.figure()
gs = fig.add_gridspec(3, hspace=0.1)
ax = gs.subplots(sharex=True, sharey=True)
# Top to bottom: policies under the D/C, 1/C and D penalties, using the same
# colours as the penalty figure.
panel_data = [
    (failure, success, cred),
    (failure2, success2, cgreen),
    (failure1, success1, cblue),
]
for panel, (fail, succ, colour) in zip(ax, panel_data):
    panel.plot(p, fail, label=r'Failure', c=colour, lw=lw)
    panel.plot(p, succ, label=r'Success', linestyle="--", c=colour, lw=lw)
for panel in ax:
    panel.label_outer()
    panel.legend()

plt.xlabel("p")
fig.tight_layout()
fig.savefig("{}/{}.pdf".format(args.path, "success_failure_rates"), bbox_inches='tight')
plt.show()
